Plotting Cross-Validated Predictions

展示如何使用 cross_val_predict 将预测 errors 可视化展示.


In [4]:
# %load ../common_import.py
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn import datasets

In [5]:
from sklearn.model_selection import cross_val_predict
from sklearn import linear_model

In [12]:
lr = linear_model.LinearRegression()
boston = datasets.load_boston()
# 将数据转成 DataFrame 方便查看
target = pd.DataFrame(boston.target)
data = pd.DataFrame(boston.data)

In [14]:
predicted = cross_val_predict(lr, data, target, cv=10)

In [21]:
fig, ax = plt.subplots()
ax.scatter(target, predicted, edgecolors=(0, 0, 0))
ax.plot([target.min(), target.max()], [target.min(), target.max()], 'k--', lw=4)
ax.set_xlabel('Measured')
ax.set_ylabel('Predicted')
plt.show()



In [25]:
target['predicted'] = predicted


---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/anaconda/envs/tflearn/lib/python3.5/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   2441             try:
-> 2442                 return self._engine.get_loc(key)
   2443             except KeyError:

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc (pandas/_libs/index.c:5280)()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc (pandas/_libs/index.c:5126)()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item (pandas/_libs/hashtable.c:20523)()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item (pandas/_libs/hashtable.c:20477)()

KeyError: '0'

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
<ipython-input-25-57218a5e4399> in <module>()
      1 target['predicted'] = predicted
----> 2 target['error'] = target['0'] - target['predicted']

~/anaconda/envs/tflearn/lib/python3.5/site-packages/pandas/core/frame.py in __getitem__(self, key)
   1962             return self._getitem_multilevel(key)
   1963         else:
-> 1964             return self._getitem_column(key)
   1965 
   1966     def _getitem_column(self, key):

~/anaconda/envs/tflearn/lib/python3.5/site-packages/pandas/core/frame.py in _getitem_column(self, key)
   1969         # get column
   1970         if self.columns.is_unique:
-> 1971             return self._get_item_cache(key)
   1972 
   1973         # duplicate columns & possible reduce dimensionality

~/anaconda/envs/tflearn/lib/python3.5/site-packages/pandas/core/generic.py in _get_item_cache(self, item)
   1643         res = cache.get(item)
   1644         if res is None:
-> 1645             values = self._data.get(item)
   1646             res = self._box_item_values(item, values)
   1647             cache[item] = res

~/anaconda/envs/tflearn/lib/python3.5/site-packages/pandas/core/internals.py in get(self, item, fastpath)
   3588 
   3589             if not isnull(item):
-> 3590                 loc = self.items.get_loc(item)
   3591             else:
   3592                 indexer = np.arange(len(self.items))[isnull(self.items)]

~/anaconda/envs/tflearn/lib/python3.5/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   2442                 return self._engine.get_loc(key)
   2443             except KeyError:
-> 2444                 return self._engine.get_loc(self._maybe_cast_indexer(key))
   2445 
   2446         indexer = self.get_indexer([key], method=method, tolerance=tolerance)

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc (pandas/_libs/index.c:5280)()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc (pandas/_libs/index.c:5126)()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item (pandas/_libs/hashtable.c:20523)()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item (pandas/_libs/hashtable.c:20477)()

KeyError: '0'

In [38]:
target = target.rename(columns={0:'target'})

In [43]:
target['error'] = target['target'] - target['predicted']
target.head()


Out[43]:
target predicted error
0 24.0 30.053132 -6.053132
1 21.6 24.735976 -3.135976
2 34.7 30.364361 4.335639
3 33.4 28.321366 5.078634
4 36.2 27.545057 8.654943

In [44]:
fig, ax = plt.subplots()
ax.hist(target['error'])


Out[44]:
(array([   2.,   10.,  125.,  256.,   81.,   20.,    5.,    3.,    3.,    1.]),
 array([-20.27045931, -14.48997257,  -8.70948583,  -2.92899909,
          2.85148765,   8.63197439,  14.41246113,  20.19294787,
         25.97343461,  31.75392135,  37.53440809]),
 <a list of 10 Patch objects>)

In [ ]:
# TODO 数据 normalization